{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "655a660d-c302-423a-8054-21c77e87eae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['OMP_NUM_THREADS'] = \"16\"\n",
    "os.environ['OMP_DISPLAY_ENV'] = \"TRUE\"\n",
    "os.environ['OMP_PROC_BIND'] = \"TRUE\"\n",
    "os.environ['OMP_SCHEDULE'] = \"STATIC\"\n",
    "os.environ['GOMP_CPU_AFFINITY'] = \"0-15\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a20732e-6c90-418a-9817-4c1b4e2a2077",
   "metadata": {
    "id": "5a20732e-6c90-418a-9817-4c1b4e2a2077",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import logging\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch._dynamo\n",
    "import torch._inductor\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision import models\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torch.profiler import profile, record_function, ProfilerActivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c80752-d9c6-455c-88c4-130db1688134",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(torch.__config__.parallel_info())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f92994f-ffe8-4e80-9c46-8963aade1246",
   "metadata": {
    "id": "0f92994f-ffe8-4e80-9c46-8963aade1246",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def build_data_loader(data_dir, batch_size, random_seed=42, valid_size=0.1, shuffle=True, test=False):\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "    transform.crop_size=224\n",
    "    transform.resize_size=224\n",
    "    \n",
    "    train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)\n",
    "    test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)\n",
    "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "    return (train_loader, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d745f184-7704-4850-9f57-4268cf20102e",
   "metadata": {
    "id": "d745f184-7704-4850-9f57-4268cf20102e",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(model, data_loader, num_epochs, criterion, optimizer, device):\n",
    "    total_steps = len(train_loader)\n",
    "    for epoch in range(num_epochs):\n",
    "        start = time.time()\n",
    "        for step, (images, labels) in enumerate(train_loader):  \n",
    "            # Move tensors to the configured device\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "        \n",
    "            # Forward pass\n",
    "            outputs = model(images)\n",
    "            outputs.to(device)\n",
    "            loss = criterion(outputs, labels)\n",
    "        \n",
    "            # Backward and optimize\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        end = time.time()\n",
    "               \n",
    "        print ('Epoch [{}/{}], Loss: {:.4f}, time: {} seconds'.format(epoch+1, num_epochs, loss.item(), int(end-start)))\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1843b38-0eff-476b-8ed0-2a1646dbb72f",
   "metadata": {
    "id": "e1843b38-0eff-476b-8ed0-2a1646dbb72f",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test(model, test_loader, device):\n",
    "    with torch.no_grad():\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        for images, labels in test_loader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            outputs = model(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "            del images, labels, outputs\n",
    "\n",
    "    print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "348a57dd-8fb2-4e0a-8407-bd9dfb282605",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(CNN, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size = 2, stride = 2))\n",
    "        \n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size = 2, stride = 2))\n",
    "\n",
    "        self.fc1 = nn.Linear(4096, 512)\n",
    "        \n",
    "        self.fc2 = nn.Linear(512, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        out = out.reshape(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.fc2(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76056f95-3978-44e4-94a0-a9c4c7ee21fc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "76056f95-3978-44e4-94a0-a9c4c7ee21fc",
    "outputId": "1066a5d6-c889-402e-864c-7d7359d1ef91",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# General parameters\n",
    "data_dir = '/tmp'\n",
    "device = \"cpu\"\n",
    "\n",
    "# Hyperparameters\n",
    "lr = 0.00001\n",
    "weight_decay = 0.005\n",
    "batch_size = 64\n",
    "num_epochs = 10\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam\n",
    "\n",
    "# FashionMNIST dataset \n",
    "train_loader, test_loader = build_data_loader(data_dir=data_dir, batch_size=batch_size)\n",
    "\n",
    "# Models\n",
    "model = CNN().to(device)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optimizer(model.parameters(), lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1b0241-6c87-40cf-8e55-d8cddf5cc71e",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 537
    },
    "id": "db1b0241-6c87-40cf-8e55-d8cddf5cc71e",
    "outputId": "af1fa5de-b53a-4c5b-edb5-4c202ed455c3",
    "tags": []
   },
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "train(model, train_loader, num_epochs, criterion, optimizer, device)\n",
    "end = time.time()\n",
    "print('Training time: {} seconds'.format(int(end-start)))\n",
    "test(model, test_loader, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e794d3a5-5955-4657-b191-b4b2c2b71b8b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e794d3a5-5955-4657-b191-b4b2c2b71b8b",
    "outputId": "5f53d802-4aeb-4f0d-f555-3db2ac613ae4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "activities = [ProfilerActivity.CPU]\n",
    "prof = profile(activities=activities)\n",
    "        \n",
    "input_sample, _ = next(iter(train_loader))\n",
    "\n",
    "prof.start()\n",
    "train(model, train_loader, num_epochs, criterion, optimizer, device)\n",
    "prof.stop()\n",
    "\n",
    "print(prof.key_averages().table(sort_by=\"self_cpu_time_total\", row_limit=10))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
